{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put the directory two levels up on the import path (presumably the repo root that\n",
    "# holds the `academicodec` package -- verify against the checkout layout).\n",
    "# The entry is relative to the kernel's working directory. The membership check keeps\n",
    "# the cell idempotent: re-running it does not pile up duplicate sys.path entries.\n",
    "if '../../' not in sys.path:\n",
    "    sys.path.append('../../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init model and load weights\n",
      "Model ready\n",
      "Globbed 12 wav files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████| 1/1 [00:00<00:00, 11.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wav.shape: (97681,)\n",
      "acoustic_token: tensor([[[ 11, 591, 281, 629],\n",
      "         [733, 591, 401, 139],\n",
      "         [500, 591, 733, 600],\n",
      "         ...,\n",
      "         [733, 591, 451, 346],\n",
      "         [733, 591, 401, 139],\n",
      "         [386, 591, 281, 461]]], device='cuda:0')\n",
      "acoustic_token.shape: torch.Size([1, 305, 4])\n",
      "acoustic_token.dtype: torch.int64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import librosa\n",
    "import torch\n",
    "from academicodec.models.hificodec.vqvae import VQVAE\n",
    "from librosa.util import normalize\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Pretrained HiFi-Codec weights and the JSON config that describes the model.\n",
    "config_path = './config_24k_320d.json'\n",
    "ckpt_path = './checkpoint/HiFi-Codec-24k-320d'\n",
    "\n",
    "# The sampling rate is read from the config; wav files are loaded at this rate below.\n",
    "with open(config_path, 'r') as f:\n",
    "    config = json.load(f)\n",
    "sample_rate = config['sampling_rate']\n",
    "\n",
    "# Where wavs are read from, where results are written, and the cap on globbed files.\n",
    "inputdir = './test_wav'\n",
    "outputdir = './output'\n",
    "num = 1024\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Encode wav files into HiFi-Codec acoustic tokens and save them keyed by file id.\n",
    "    Path(outputdir).mkdir(parents=True, exist_ok=True)\n",
    "    print(\"Init model and load weights\")\n",
    "    # make sure you downloaded the weights from https://huggingface.co/Dongchao/AcademiCodec/blob/main/HiFi-Codec-24k-320d\n",
    "    # and put it in ./checkpoint/\n",
    "    model = VQVAE(\n",
    "        config_path,\n",
    "        ckpt_path,\n",
    "        with_encoder=True)\n",
    "    model.cuda()\n",
    "    model.eval()\n",
    "    print(\"Model ready\")\n",
    "\n",
    "    # glob returns paths in arbitrary (filesystem-dependent) order; sort first so the\n",
    "    # [:num] cut selects the same files on every run and every machine.\n",
    "    wav_paths = sorted(glob.glob(f\"{inputdir}/*.wav\"))[:num]\n",
    "    print(f\"Globbed {len(wav_paths)} wav files.\")\n",
    "    fid_to_acoustic_token = {}\n",
    "    # NOTE(review): only the first path is encoded (looks like a demo/debug limit);\n",
    "    # drop the [:1] slice to encode every globbed file.\n",
    "    for wav_path in tqdm(wav_paths[:1]):\n",
    "        # File id = file name without the extension matched by the glob.\n",
    "        fid = os.path.splitext(os.path.basename(wav_path))[0]\n",
    "        wav, sr = librosa.load(wav_path, sr=sample_rate)\n",
    "        print(\"wav.shape:\",wav.shape)\n",
    "        assert sr == sample_rate\n",
    "        # Peak-normalize, then scale to 0.95 before handing the audio to the encoder.\n",
    "        wav = normalize(wav) * 0.95\n",
    "        # Add a leading batch dimension and move the waveform to the GPU.\n",
    "        wav = torch.FloatTensor(wav).unsqueeze(0)\n",
    "        wav = wav.to(torch.device('cuda'))\n",
    "        # Inference only: skip autograd bookkeeping while encoding.\n",
    "        with torch.no_grad():\n",
    "            acoustic_token = model.encode(wav)\n",
    "        print(\"acoustic_token:\",acoustic_token)\n",
    "        print(\"acoustic_token.shape:\",acoustic_token.shape)\n",
    "        print(\"acoustic_token.dtype:\",acoustic_token.dtype)\n",
    "        fid_to_acoustic_token[fid] = acoustic_token\n",
    "\n",
    "    # Tokens are saved on whatever device model.encode returned them on (cuda:0 in the\n",
    "    # recorded run); pass map_location to torch.load when loading without a GPU.\n",
    "    torch.save(fid_to_acoustic_token,\n",
    "               os.path.join(outputdir, 'fid_to_acoustic_token.pth'))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
